-- Piotr Bober
-- Compiler Construction 2009/10
-- the compiler
module Compiler
   (compile, compileFun, generate, IntCommand(..), IntCode, LocalVariables)
   where

import Maszyna
import DataTypes
import Prelude hiding (negate, div, lookup)
import Char (ord)
import Data.Map (Map, empty, fromList, insert, lookup)
import Data.List (scanl)
import System.IO (Handle, hPutStrLn)

-- | Intermediate code representation: a flat list of instructions. Jump
-- offsets are relative to the jump instruction itself (loops use negative
-- offsets); 'Call' refers to a function by name until 'genCommand' resolves
-- it to an offset relative to the call site.
type IntCode = [IntCommand]

-- | One intermediate instruction. 'genCommand' lowers each constructor to
-- exactly one machine 'Command'.
data IntCommand = Exit
                | SkipM
                | Store Double        -- ^ store; the Double is the value to be stored
                | Load                -- ^ pop
                | Copy
                | Arithmetic String   -- ^ operator: "+", "-", "*" or "/" (the only spellings 'genCommand' accepts)
                | Negate
                | Boolean String      -- ^ operator: "<", "<=", ">", ">=", "==" or "!=" (the only spellings 'genCommand' accepts)
                | Read' Int           -- ^ slot index already translated from a variable address (e.g. varcount - adr + 1); NOTE(review): presumably counted from the stack top, 1 = top — confirm in Maszyna
                | Write Int           -- ^ slot index already translated from a variable address, same convention as 'Read''
                | Jump Int            -- ^ unconditional jump, relative offset
                | JumpZero Int        -- ^ jumpzero, relative offset
                | JumpNonZero Int     -- ^ jumpnonzero, relative offset
                | Call String         -- ^ call; the String is the callee's function name
                | Ret
                | Input
                | Output
                | OutputChar
   deriving (Show, Eq)

-- | Address of every variable of one function: parameters are numbered from
-- 1 in declaration order, 'Init'-declared locals get the following numbers
-- (see 'readLocalVariables').
type LocalVariables = Map Variable Int


-- | Compile a whole program to machine code.
--
-- Layout: address 0 holds a jump to @main@, the functions follow one after
-- another starting at address 1, and a final 'exit' closes the program.
-- As a side effect the intermediate code, in the same layout, is written to
-- @fh@: one instruction per line, with a blank line after each function.
-- Fails with 'error' when no function is named @main@.
compile :: Handle -> Program -> IO Code
compile fh prog =
   case lookup "main" table of
      Nothing    -> error "compile: function \"main\" not found"
      Just entry -> do
         let listing = [Jump entry] : bodies ++ [[Exit]]
         hPutStrLn fh (unlines (map (unlines . map show) listing))
         return (jump entry : concatMap (generate table) (zip starts bodies) ++ [exit])
   where
      compiled       = map compileFun prog
      (names, sizes) = unzip (map fst compiled)
      bodies         = map snd compiled
      -- start address of every function: running sum of the preceding sizes
      starts         = init (scanl (+) 1 sizes)
      table          = fromList (zip names starts)

-- | Number of local slots the function called @name@ needs: one per
-- parameter plus one per 'Init' statement anywhere in its body, including
-- 'Init's nested inside 'If', 'If'', 'While', 'Do' and 'Seq' blocks.
--
-- This must agree with the slot count computed by 'readLocalVariables',
-- because that count is what every stack depth and the 'Return' epilogue
-- are computed from. Counting only top-level 'Init's here would leave nested
-- locals without a pre-allocated slot.
--
-- Returns 0 when the program has no function called @name@.
countVars :: String -> Program -> Int
countVars _ [] = 0
countVars name ((FunDef name' args stmts):_) | name == name' =
   length args + countInits stmts
   where
      countInits = sum . map initsIn
      -- walk the same nested statement lists as readLocalVariables does
      initsIn stmt = case stmt of
         Init _ _   -> 1
         If _ c1 c2 -> countInits c1 + countInits c2
         If' _ c    -> countInits c
         While _ c  -> countInits c
         Do _ c     -> countInits c
         Seq ss     -> countInits ss
         _          -> 0
countVars name (_:fs) = countVars name fs

-- | Lower one function's intermediate code to machine code. @posStart@ is
-- the absolute address of its first instruction; the i-th instruction is
-- lowered as if located at @posStart + i@, which 'genCommand' needs to turn
-- calls into position-relative offsets.
generate :: Map String Int -> (Int, IntCode) -> Code
generate m (posStart, ic) = zipWith (\addr cmd -> genCommand m (addr, cmd)) [posStart ..] ic

-- | Translate one intermediate instruction into a machine 'Command'.
--
-- The 'Int' paired with the instruction is its absolute address; it only
-- matters for 'Call', whose target (looked up by name in @m@) is emitted as
-- an offset relative to the call site. Unknown operator strings and unknown
-- function names raise 'error'.
genCommand :: Map String Int -> (Int, IntCommand) -> Command
genCommand m (pos, cmd) = lower cmd
   where
      lower Exit            = exit
      lower SkipM           = skip
      lower (Store x)       = store x
      lower Load            = load
      lower Copy            = copy
      lower (Arithmetic op) = arithmetic op
      lower Negate          = negate
      lower (Boolean op)    = comparison op
      lower (Read' n)       = read' n
      lower (Write n)       = write n
      lower (Jump n)        = jump n
      lower (JumpZero n)    = jumpzero n
      lower (JumpNonZero n) = jumpnonzero n
      lower (Call name)     = maybe (error $ "genCommand: function " ++ name ++ " not found")
                                    (\target -> call (target - pos))
                                    (lookup name m)
      lower Ret             = ret
      lower Input           = input
      lower Output          = output
      lower OutputChar      = outputchar

      -- accepted spellings of the Arithmetic operator
      arithmetic "+" = add
      arithmetic "-" = sub
      arithmetic "*" = mul
      arithmetic "/" = div
      arithmetic op  = error $ "genCommand: unknown operator: " ++ op

      -- accepted spellings of the Boolean (relational) operator
      comparison "<"  = less
      comparison "<=" = lessequal
      comparison ">"  = greater
      comparison ">=" = greaterequal
      comparison "==" = equal
      comparison "!=" = notequal
      comparison op   = error $ "genCommand: unknown operator: " ++ op

-- | Compile one function definition.
--
-- The code is a prologue of @Store 0.0@ instructions, one for every local
-- slot beyond the parameters (as counted by 'countVars'), followed by the
-- compiled body. Returns the function's name paired with the total code
-- length, together with the code itself.
compileFun :: FunDef -> ((String, Int), IntCode)
compileFun f@(FunDef name args stmts) = ((name, length body + slots), prologue ++ body)
   where
      (varcount, locals) = readLocalVariables f
      body               = concatMap (compileStmt varcount locals) stmts
      slots              = countVars name [f] - length args
      prologue           = replicate slots (Store 0.0)

-- | Assign an address to every variable of a function.
--
-- Parameters get addresses 1..n in declaration order. For a repeated
-- parameter name, the first occurrence wins. Then every 'Init' found by a
-- depth-first, left-to-right walk of the body (including nested 'If',
-- 'If'', 'While', 'Do' and 'Seq' blocks) gets the next free address. A later
-- 'Init' of an already-known name replaces its address.
--
-- Returns the total number of slots and the variable-to-address map.
readLocalVariables :: FunDef -> (Int, LocalVariables)
readLocalVariables (FunDef _ args stmts) = walk (arity, params) stmts
   where
      arity  = length args
      params = foldr (uncurry insert) empty (zip args [1 .. arity])

      walk acc []       = acc
      walk acc (s : ss) = walk (visit acc s) ss

      visit acc@(count, vars) s = case s of
         If _ c1 c2  -> walk (walk acc c1) c2
         If' _ c     -> walk acc c
         While _ c   -> walk acc c
         Do _ c      -> walk acc c
         Seq inner   -> walk acc inner
         Init v _    -> (count + 1, insert v (count + 1) vars)
         PrintStr _  -> acc
         Print _     -> acc
         Read _      -> acc
         Return _    -> acc
         Assign _ _  -> acc
         Skip        -> acc

-- | Compile one statement of a function body.
--
-- @varcount@ is the number of local slots of the enclosing function
-- (parameters plus 'Init'-declared variables) and @locals@ maps each
-- variable to its address (1 = first parameter), as produced by
-- 'readLocalVariables'. A variable with address @adr@ is addressed with the
-- index @varcount - adr + 1@ (one more when a freshly computed value is
-- already on top of it).
--
-- NOTE(review): the offsets below are consistent with these assumptions,
-- inferred from how this module uses the instructions (confirm in Maszyna):
--
-- * 'Read'' k moves the k-th stack element (1 = top) to the top.
-- * 'Write' k pops the top and re-inserts it at depth k.
-- * Conditional jumps leave the tested value on the stack, hence the 'Load'
--   on both paths after each of them.
-- * Jump offsets are relative to the jump instruction.
compileStmt :: Int -> LocalVariables -> Statement -> IntCode
compileStmt varcount locals stmt = case stmt of
   -- if/else: on a false condition jump to the Load in front of c2 (offset
   -- n1+3); otherwise pop the condition, run c1 and jump past c2 (n2+2).
   If b c1 c2  -> let
                     ic1 = compileStmt varcount locals $ Seq c1
                     ic2 = compileStmt varcount locals $ Seq c2
                     n1  = length ic1
                     n2  = length ic2
                  in
                     evalB varcount locals b ++ (JumpZero (n1+3) : Load : ic1) ++ (Jump (n2+2) : Load : ic2)
   -- if without else: on false jump to the trailing Load (offset n+3);
   -- otherwise pop the condition, run c and skip that Load with Jump 2.
   If' b c     -> let
                     ic = compileStmt varcount locals $ Seq c
                     n  = length ic
                  in
                     evalB varcount locals b ++ (JumpZero (n+3) : Load : ic) ++ [Jump 2, Load]
   -- while: on false jump to the final Load (offset n+3); otherwise pop the
   -- condition, run the body and jump back to the first instruction of the
   -- condition code (offset -n-m-2).
   While b c   -> let
                     ic = compileStmt varcount locals $ Seq c
                     ib = evalB varcount locals b
                     n  = length ic
                     m  = length ib
                  in
                     ib ++ (JumpNonZero (n+3) : Load : ic) ++ [Jump (-n-m-2), Load]
   -- do-while: the leading Store 0.0/Load pair gives the loop-back target of
   -- JumpNonZero (offset -m-n-1) a Load that pops the condition value on
   -- every further iteration; on first entry it pops the 0.0 just stored.
   Do b c      -> let
                     ic = compileStmt varcount locals $ Seq c
                     ib = evalB varcount locals b
                     n  = length ic
                     m  = length ib
                  in
                     Store 0.0 : Load : (ic ++ ib ++ [JumpNonZero (-m-n-1), Load])
   -- each character is stored as its code point, printed and popped
   PrintStr s  -> concatMap (\c -> [Store $ fromIntegral $ ord c, OutputChar, Load]) s
   Print a     -> evalA varcount locals a ++ [Output, Load]
   -- fetch v's slot and drop the old value, then read a new value with
   -- Input and write it back at v's index
   Read v      -> case lookup v locals of
                    Nothing  -> error $ "compileStmt: variable " ++ v ++ " not found"
                    Just adr -> [Read' $ varcount - adr + 1, Load, Input, Write $ varcount - adr + 1]
   -- put the result below all locals (index varcount+1), pop the varcount
   -- locals, then return
   Return a    -> evalA varcount locals a ++ Write (varcount + 1) : replicate varcount Load ++ [Ret]
   -- evaluate the value; fetch and drop v's old value (index is one larger
   -- because the new value is on top), then write the new value into v's slot
   Assign v a  -> case lookup v locals of
                    Nothing  -> error $ "compileStmt: variable " ++ v ++ " not found"
                    Just adr -> evalA varcount locals a ++ [(Read' $ varcount - adr + 2), Load, (Write $ varcount - adr + 1)]
   -- compiled exactly like Assign: the slot itself is pre-allocated by the
   -- Store 0.0 prologue that compileFun emits
   Init v a    -> case lookup v locals of
                    Nothing  -> error $ "compileStmt: variable " ++ v ++ " not found"
                    Just adr -> evalA varcount locals a ++ [(Read' $ varcount - adr + 2), Load, (Write $ varcount - adr + 1)]
   Skip        -> [SkipM]
   Seq xs      -> concatMap (compileStmt varcount locals) xs

-- | Compile an arithmetic expression of a function whose frame has
-- @varcount@ local slots addressed through @locals@.
--
-- Intermediate results (left operands, call arguments) are parked below the
-- locals with @Write (varcount + 1)@, so variable indices stay valid while
-- the rest of the expression is evaluated. Only unary "-" is supported;
-- other unary operators and unknown variables raise 'error'.
evalA :: Int -> LocalVariables -> AExp -> IntCode
evalA varcount locals = go
   where
      go (Const x)          = [Store x]
      go (Var v)            = maybe (error $ "evalA: variable " ++ v ++ " not found") fetch (lookup v locals)
      go (FCall name args)  = putArgs varcount locals 0 args ++ [Call name]
      go (UnOp "-" operand) = go operand ++ [Negate]
      go (UnOp op _)        = error $ "evalA: unknown operator: " ++ op
      go (BinOp op lhs rhs) = go lhs ++ [Write (varcount + 1)] ++ go rhs ++ [Read' (varcount + 2), Arithmetic op]

      -- bring the variable up, duplicate it and put one copy back in place
      fetch adr =
         let slot = varcount - adr + 1
         in  [Read' slot, Copy, Write (slot + 1)]

-- | Compile the actual arguments of a function call.
--
-- @argcount@ is the number of arguments already emitted (callers pass 0).
--
-- Each argument is evaluated and parked just below the caller's @varcount@
-- locals with @Write (varcount + 1)@, so the caller's variable indices stay
-- valid while later arguments are evaluated. After that the arguments lie
-- below the locals with the first one deepest.
--
-- They are then pulled back above the locals, oldest first. Every pull reads
-- from the same index, @varcount + argcount@, because each pull moves the
-- next-oldest argument to exactly that depth. The result has the first
-- argument deepest and the last one on top. This is the layout the callee
-- assumes, since 'readLocalVariables' gives the first parameter address 1,
-- i.e. the deepest slot.
--
-- Pulling with increasing indices (varcount+1, varcount+2, ...) would
-- instead fetch the newest argument first and reverse the parameters of
-- every call with two or more arguments.
--
-- NOTE(review): assumes Read' k moves the k-th stack element (1 = top) to the
-- top and Write k pops the top and re-inserts it at depth k, which is how the
-- rest of this module uses them; confirm against Maszyna.
putArgs :: Int -> LocalVariables -> Int -> Args -> IntCode
putArgs varcount _ argcount [] = replicate argcount (Read' $ varcount + argcount)
putArgs varcount locals argcount (x:xs) = evalA varcount locals x ++ Write (varcount + 1) : putArgs varcount locals (argcount + 1) xs

-- | Compile a boolean expression of a function whose frame has @varcount@
-- local slots addressed through @locals@.
--
-- TRUE/FALSE become 1.0/0.0. Negation branches on the operand's value.
-- "&&" and "||" short-circuit: if the left operand already decides the
-- result, its value is kept and the right operand's code is jumped over.
-- Relations evaluate both sides like arithmetic and finish with a 'Boolean'
-- instruction. Unknown connectives raise 'error'.
evalB :: Int -> LocalVariables -> BExp -> IntCode
evalB varcount locals = go
   where
      go FALSE                    = [Store 0.0]
      go TRUE                     = [Store 1.0]
      go (Neg inner)              = go inner ++ [JumpZero 4, Load, Store 0.0, Jump 3, Load, Store 1.0]
      go (BoolBinOp "&&" lhs rhs) = shortCircuit JumpZero lhs rhs
      go (BoolBinOp "||" lhs rhs) = shortCircuit JumpNonZero lhs rhs
      go (BoolBinOp op _ _)       = error $ "evalB: unknown operator: " ++ op
      go (Rel op lhs rhs)         = arith lhs ++ [Write (varcount + 1)] ++ arith rhs ++ [Read' (varcount + 2), Boolean op]

      arith = evalA varcount locals

      -- jump over "Load ++ right operand" when the left value decides the result
      shortCircuit skipWhen lhs rhs =
         let rest = go rhs
         in  go lhs ++ [skipWhen (length rest + 2), Load] ++ rest
